
[MoE Refactor][14/N] Clean Up FI Quant Config Smuggling#31593

Merged
robertgshaw2-redhat merged 62 commits into main from fix-flashinfer-experts-quant-config-hack
Jan 6, 2026
Conversation

@robertgshaw2-redhat
Collaborator

@robertgshaw2-redhat robertgshaw2-redhat commented Jan 1, 2026

Purpose

  • MoE refactor: FlashInfer smuggles scales for certain kernels via NVFP4 global scales; this PR cleans that up.
  • Potentially fixes issues with FlashInfer per-tensor for non-ModelOpt checkpoints. It avoids crashing Mixtral, but we still see 0% accuracy on Mixtral; that will be addressed in a follow-up PR.

Test Plan

```just
# autofp8
MODEL_BLOCK := "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8"
# MODEL_TENSOR := "amd/Mixtral-8x7B-Instruct-v0.1-FP8-KV"

# modelopt
MODEL_TENSOR := "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"

GPUS := "2"
PORT := "8001"

# sm90
launch_cutlass_block:
	VLLM_USE_DEEP_GEMM=0 VLLM_USE_FLASHINFER_MOE_FP8=1 VLLM_FLASHINFER_MOE_BACKEND=throughput vllm serve {{MODEL_BLOCK}} -tp {{GPUS}} --port {{PORT}}

# sm90
launch_cutlass_tensor:
	VLLM_USE_DEEP_GEMM=0 VLLM_USE_FLASHINFER_MOE_FP8=1 VLLM_FLASHINFER_MOE_BACKEND=throughput vllm serve {{MODEL_TENSOR}} -tp {{GPUS}} --port {{PORT}} --max-model-len 8192

# sm100
launch_trtllm_block:
	VLLM_USE_DEEP_GEMM=0 VLLM_USE_FLASHINFER_MOE_FP8=1 VLLM_FLASHINFER_MOE_BACKEND=latency chg run --gpus {{GPUS}} -- vllm serve {{MODEL_BLOCK}} -tp {{GPUS}}

# sm100
launch_trtllm_tensor:
	VLLM_USE_DEEP_GEMM=0 VLLM_USE_FLASHINFER_MOE_FP8=1 VLLM_FLASHINFER_MOE_BACKEND=latency chg run --gpus {{GPUS}} -- vllm serve {{MODEL_TENSOR}} -tp {{GPUS}} --max-model-len 8192

eval_block:
	lm_eval \
		--model local-completions \
		--tasks gsm8k \
		--model_args "model={{MODEL_BLOCK}},base_url=http://localhost:{{PORT}}/v1/completions,num_concurrent=1000,tokenized_requests=False"

eval_tensor:
	lm_eval \
		--model local-completions \
		--tasks gsm8k \
		--model_args "model={{MODEL_TENSOR}},base_url=http://localhost:{{PORT}}/v1/completions,num_concurrent=1000,tokenized_requests=False"
```
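For context, the eval recipes drive vLLM's OpenAI-compatible `/v1/completions` endpoint. Here is a minimal sketch of the request body that an evaluation client posts to that endpoint (model name and port taken from the recipes above; the helper function is invented for illustration, and actually sending the request requires a running server):

```python
# Hypothetical sketch: build the JSON body for a POST to vLLM's
# OpenAI-compatible completions endpoint (http://localhost:8001/v1/completions).
# Only the request construction is shown; nothing is sent over the network.
import json


def build_completion_request(model: str, prompt: str, max_tokens: int = 256) -> bytes:
    payload = {
        "model": model,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": 0.0,  # greedy decoding, typical for gsm8k-style evals
    }
    return json.dumps(payload).encode("utf-8")


body = build_completion_request(
    "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8",
    "Q: What is 12 * 7?\nA:",
)
print(json.loads(body)["model"])
```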

Test Result

Llama4 Scout

  • cutlass tensor (h100 / b200)
- h100
local-completions (model=nvidia/Llama-4-Scout-17B-16E-Instruct-FP8,base_url=http://localhost:8002/v1/completions,num_concurrent=1000,tokenized_requests=False), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match||0.9227|±  |0.0074|
|     |       |strict-match    |     5|exact_match||0.9075|±  |0.0080|
  • trtllm tensor (b200)
local-completions (model=nvidia/Llama-4-Scout-17B-16E-Instruct-FP8,base_url=http://localhost:8000/v1/completions,num_concurrent=1000,tokenized_requests=False), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match||0.9242|±  |0.0073|
|     |       |strict-match    |     5|exact_match||0.9075|±  |0.0080|

Qwen3-30B

  • cutlass block (h100)
local-completions (model=Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8,base_url=http://localhost:8001/v1/completions,num_concurrent=1000,tokenized_requests=False), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match||0.8749|±  |0.0091|
|     |       |strict-match    |     5|exact_match||0.8855|±  |0.0088|
  • trtllm block (b200)
local-completions (model=Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8,base_url=http://localhost:8000/v1/completions,num_concurrent=1000,tokenized_requests=False), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match||0.8787|±  |0.0090|
|     |       |strict-match    |     5|exact_match||0.8931|±  |0.0085|

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Robert Shaw added 3 commits December 31, 2025 22:32
@mergify bot added the llama (Related to Llama models) and nvidia labels Jan 1, 2026
Robert Shaw added 6 commits January 1, 2026 00:45
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the quantization configuration for FlashInfer experts, removing a hack and improving code clarity. The changes appear correct and align with the goal of cleaning up the implementation. However, the PR includes several debugging statements (e.g., logger.info, print, timing code) across multiple files. These should be removed before merging to avoid polluting logs and potential performance impacts. Additionally, there is a critical issue in vllm/model_executor/models/llama4.py where a logging statement references a variable defined in a commented-out block, which will cause a NameError at runtime.

I am having trouble creating individual review comments; my feedback is below.

vllm/model_executor/models/llama4.py (524-539)

critical

This block contains commented-out debugging code and an active logging statement that references start, a variable defined within the commented-out section. This will lead to a NameError at runtime. The entire block should be removed.
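The failure mode described here is a common one when debugging code is commented out piecemeal. A minimal, hypothetical repro (names invented, not vLLM's actual code):

```python
# Repro of the bug class flagged above: a log line references a variable
# that is only defined inside commented-out debugging code, so the
# function raises NameError at runtime instead of logging a timing.
import time


def forward_with_stale_debug_log() -> int:
    # start = time.perf_counter()  # the debug timing setup was commented out...
    result = sum(range(10))
    # ...but this log line still references `start`, so it raises NameError.
    print(f"forward took {time.perf_counter() - start:.3f}s")
    return result


try:
    forward_with_stale_debug_log()
except NameError as exc:
    print(f"crashed: {exc}")  # crashed: name 'start' is not defined
```

Removing the whole block (both the commented-out setup and the live log line), as the review suggests, avoids the crash.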

vllm/model_executor/layers/fused_moe/layer.py (954-956)

high

This logging statement appears to be for debugging purposes and should be removed before merging.

vllm/model_executor/layers/fused_moe/layer.py (1012-1021)

high

This block of code, including the time import and performance logging, seems to be for debugging and should be removed.

vllm/model_executor/layers/fused_moe/layer.py (1274)

high

This logger.info call appears to be for debugging. Please remove it, along with similar debugging logs at lines 1292, 1304, and the commented-out log at line 1366.

vllm/model_executor/models/mllama4.py (1126)

high

This print statement appears to be for debugging. Please remove it, along with the other debug prints in this function at lines 1130, 1136, and 1141.

vllm/model_executor/models/utils.py (198)

high

This logger.info call appears to be for debugging. Please remove it, along with the other debug logs added in this file (lines 219, 255, 264, 292, 302, 334).

Robert Shaw added 5 commits January 1, 2026 14:56
@robertgshaw2-redhat
Collaborator Author

Now working e2e with FlashInfer CUTLASS.

Need to make a few more small fixes for FlashInfer TRTLLM.

Robert Shaw added 10 commits January 1, 2026 16:43
Robert Shaw and others added 5 commits January 5, 2026 14:04
@robertgshaw2-redhat
Collaborator Author

Just reran the quality checks on top of the head commit after the nits; accuracy still looks good.

Collaborator

@pavanimajety pavanimajety left a comment


Thank you for making the changes.

@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Jan 5, 2026
@robertgshaw2-redhat
Collaborator Author

> Thank you for making the changes.

Thanks for your great feedback and review @pavanimajety !

```python
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
a1_gscale=(1.0 / layer.w13_input_scale),
```
Contributor

@amirkl94 amirkl94 Jan 5, 2026


I think this function is called every forward, which means these 2 lines will result in 2 kernel launches for reciprocal:

```python
a1_gscale=(1.0 / layer.w13_input_scale),
a2_gscale=(1.0 / layer.w2_input_scale),
```

Can we add these 2 scales in process_weights_after_loading ?

Collaborator Author

@robertgshaw2-redhat robertgshaw2-redhat Jan 5, 2026


It's not called in the forward pass. I recognize this is confusing, but the apply() method is not invoked during the forward pass for the FlashInfer kernels: when FlashInfer CUTLASS kernels are selected, the FpMoeMethod is converted into a ModularKernelMethod.

I am working on an ongoing refactor that makes the conversion

see https://vllm-dev.slack.com/archives/C08NFPURQ1F/p1767650816469009 for more details on my efforts
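Whether or not apply() sits on the hot path, the precomputation suggested above can be sketched as follows. This is a minimal, hypothetical illustration (class names and structure are invented; attribute names mirror the snippet under review, and plain floats stand in for GPU tensors):

```python
# Sketch of the suggested fix: compute the reciprocal global scales once in
# process_weights_after_loading and cache them on the layer, instead of
# evaluating 1.0 / scale on every apply() call (which, on GPU tensors, would
# launch a reciprocal kernel each time). Illustrative only, not vLLM's code.


class LayerSketch:
    """Stand-in for a quantized MoE layer; floats replace GPU tensors."""

    def __init__(self, w13_input_scale: float, w2_input_scale: float):
        self.w13_input_scale = w13_input_scale
        self.w2_input_scale = w2_input_scale


class MoEMethodSketch:
    def process_weights_after_loading(self, layer: LayerSketch) -> None:
        # One-time reciprocal at load time, cached for all later calls.
        layer.a1_gscale = 1.0 / layer.w13_input_scale
        layer.a2_gscale = 1.0 / layer.w2_input_scale

    def apply(self, layer: LayerSketch) -> dict:
        # The hot path reuses the cached values; no per-call division.
        return {
            "a1_scale": layer.w13_input_scale,
            "a2_scale": layer.w2_input_scale,
            "a1_gscale": layer.a1_gscale,
            "a2_gscale": layer.a2_gscale,
        }


layer = LayerSketch(w13_input_scale=0.5, w2_input_scale=0.25)
method = MoEMethodSketch()
method.process_weights_after_loading(layer)
print(method.apply(layer)["a1_gscale"])  # 2.0
```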

@robertgshaw2-redhat robertgshaw2-redhat enabled auto-merge (squash) January 6, 2026 04:34
@github-project-automation github-project-automation bot moved this to Backlog in MoE Refactor Jan 6, 2026
@robertgshaw2-redhat robertgshaw2-redhat moved this from Backlog to Ready in MoE Refactor Jan 6, 2026
@robertgshaw2-redhat robertgshaw2-redhat moved this from Ready to In progress in MoE Refactor Jan 6, 2026
@robertgshaw2-redhat robertgshaw2-redhat moved this from In progress to In review in MoE Refactor Jan 6, 2026
@robertgshaw2-redhat robertgshaw2-redhat merged commit af8fd73 into main Jan 6, 2026
62 checks passed
@robertgshaw2-redhat robertgshaw2-redhat deleted the fix-flashinfer-experts-quant-config-hack branch January 6, 2026 15:47
@github-project-automation github-project-automation bot moved this from In review to Done in MoE Refactor Jan 6, 2026
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Jan 6, 2026
LucasWilkinson pushed a commit to neuralmagic/vllm that referenced this pull request Jan 6, 2026
yugong333 pushed a commit to yugong333/vllm that referenced this pull request Jan 9, 2026
akh64bit pushed a commit to akh64bit/vllm that referenced this pull request Jan 16, 2026
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026

Labels

llama (Related to Llama models), nvidia, ready (ONLY add when PR is ready to merge/full CI is needed)

Projects

Status: Done

4 participants